"""Utilities for running experiments."""
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from botorch.acquisition import GenericMCObjective
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import (
    ConstrainedExpectedImprovement,
    ExpectedImprovement,
    LogConstrainedExpectedImprovement,
    LogExpectedImprovement,
    LogProbabilityOfImprovement,
    ProbabilityOfImprovement,
)
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy
from botorch.acquisition.monte_carlo import (
    qExpectedImprovement,
    qProbabilityOfImprovement,
)
from botorch.acquisition.multi_objective.analytic import ExpectedHypervolumeImprovement
from botorch.acquisition.multi_objective.monte_carlo import (
    qExpectedHypervolumeImprovement,
    qNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.objective import ConstrainedMCObjective
from botorch.acquisition.utils import get_optimal_samples
from botorch.exceptions.errors import UnsupportedError
from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.outcome import Bilog, Standardize
from botorch.optim.utils import _filter_kwargs
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.test_functions.base import (
    BaseTestProblem,
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.test_functions.multi_objective import BraninCurrin, DTLZ2, ZDT1, ZDT2, ZDT3
from botorch.test_functions.synthetic import (
    Ackley,
    Branin,
    Hartmann,
    Levy,
    Michalewicz,
    PressureVesselDesign,
    Rastrigin,
    Shekel,
    SpeedReducer,
    StyblinskiTang,
    SumOfSquares,
    TensionCompressionString,
    WeldedBeamSO,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    FastNondominatedPartitioning,
)
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.transforms import unnormalize
from .logehvi import LogExpectedHypervolumeImprovement
from .qlogehvi import (
    qLogExpectedHypervolumeImprovement,
)
from .qlogei import (
    qLogExpectedImprovement,
    qLogProbabilityOfImprovement,
)
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import SumMarginalLogLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

from torch import Tensor


def eval_problem(
    X: Tensor,
    base_function: BaseTestProblem,
    noise_se: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Evaluate a test problem at normalized inputs.

    The inputs are mapped from the unit cube to `base_function.bounds` before
    evaluation. For constrained problems the slack values are appended after
    the objective value(s) along the last dimension; for a few engineering
    problems the slack is passed through a Bilog transform first.

    Args:
        X: A `n x d`-dim tensor of inputs normalized to the unit cube.
        base_function: A test problem.
        noise_se: A tensor of observation noise standard deviations. If None,
            no noise is added.

    Returns:
        A two-element tuple containing the noiseless and noisy outcome
        values. Without noise, both elements are the same tensor.
    """
    X_raw = unnormalize(X, base_function.bounds)
    Y_true = base_function(X_raw)
    # make sure there is an explicit output dimension
    if Y_true.ndim < X_raw.ndim:
        Y_true = Y_true.unsqueeze(-1)

    if isinstance(base_function, ConstrainedBaseTestProblem):
        slack_true = base_function.evaluate_slack(X_raw)
        # bi-log is applied to the slack of these problems for modeling purposes
        bilog_problems = (TensionCompressionString, WeldedBeamSO, SpeedReducer)
        if isinstance(base_function, bilog_problems):
            print("Applying Bilog transform.")
            slack_true, _ = Bilog()(slack_true)
        Y_true = torch.cat([Y_true, slack_true], dim=-1)

    if noise_se is None:
        return Y_true, Y_true
    Y = Y_true + torch.randn_like(Y_true) * noise_se
    return Y_true, Y


def generate_initial_data(
    n: int,
    base_function: BaseTestProblem,
    bounds: Tensor,
    tkwargs: Dict[str, Any],
    noise_se: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Draw Sobol training inputs and evaluate the problem at them.

    Args:
        n: Number of training points.
        base_function: The base problem.
        bounds: A `2 x d`-dim tensor of bounds to sample the training inputs
            from. The samples are passed to `eval_problem`, which treats its
            inputs as normalized, so these are presumably unit-cube bounds.
        tkwargs: Tensor keyword arguments (e.g. dtype and device) that the
            sampled inputs are moved to.
        noise_se: Observation noise standard deviations; None means no noise.

    Returns:
        A three-element tuple of the training inputs, the noiseless outcomes,
        and the (possibly noisy) outcomes: `n x d`, `n x m`, and `n x m`.
    """
    # `draw_sobol_samples` returns `n x q x d`; drop the q=1 dimension.
    sobol_points = draw_sobol_samples(bounds=bounds, n=n, q=1)
    train_x = sobol_points.squeeze(-2).to(**tkwargs)
    train_obj_true, train_obj = eval_problem(
        train_x, base_function=base_function, noise_se=noise_se
    )
    return train_x, train_obj_true, train_obj


def initialize_model(
    train_X: Tensor,
    train_Y: Tensor,
    num_objectives: int,
    noise_var: Optional[Tensor] = None,
    use_fixed_noise: bool = False,
    **kwargs,
) -> Tuple[
    Union[ExactMarginalLogLikelihood, SumMarginalLogLikelihood],
    Union[FixedNoiseGP, SingleTaskGP, ModelListGP],
]:
    r"""Constructs the model and its MLL.

    With a single outcome, one GP is built. With several outcomes, one
    independent GP per outcome is built and wrapped in a `ModelListGP`.

    Args:
        train_X: A `n x d`-dim tensor of training inputs.
        train_Y: A `n x m`-dim tensor of training outcomes.
        num_objectives: The number of leading outcome columns that are
            objectives. Any further columns are treated as constraints and
            get a `Standardize` outcome transform (multi-outcome case only).
        noise_var: A tensor of noise variances, indexed per outcome along its
            last dimension (e.g. `m` or `1 x m`). Only used if
            `use_fixed_noise` is True.
        use_fixed_noise: If True, uses FixedNoiseGP with `noise_var` as the
            observation noise, or, if `noise_var` is None, a small noise level
            of `1e-7` times the empirical outcome variance.
        kwargs: Unused; accepted for call-site compatibility.

    Returns:
        The MLL and the model. Note: the model is not trained!
    """
    base_model_class = FixedNoiseGP if use_fixed_noise else SingleTaskGP
    d = train_X.shape[-1]
    m = train_Y.shape[1]
    if m > 1:
        models = []
        for i in range(m):
            train_Y_i = train_Y[:, i : i + 1]
            model_kwargs = {
                "train_X": train_X,
                "train_Y": train_Y_i,
                "covar_module": get_covar_module(d),
            }
            # this means it is a constraint dimension and we standardize here
            if i >= num_objectives:
                model_kwargs["outcome_transform"] = Standardize(m=1)

            if use_fixed_noise:
                if noise_var is None:
                    # noiseless
                    train_yvar = torch.full_like(train_Y_i, 1e-7) * train_Y_i.std(
                        dim=0
                    ).pow(2)
                else:
                    # Index the outcome along the last dimension so that both
                    # `m` and `1 x m` shaped noise levels work, and expand to
                    # the `n x 1` shape of this outcome's training targets.
                    train_yvar = noise_var[..., i].expand(train_Y.shape[0], 1)
                model_kwargs["train_Yvar"] = train_yvar
                # TODO: potentially use different priors
            else:
                model_kwargs["likelihood"] = get_likelihood()
            models.append(base_model_class(**model_kwargs))

        model = ModelListGP(*models)
        mll = SumMarginalLogLikelihood(model.likelihood, model)
    else:
        model_kwargs = {
            "train_X": train_X,
            "train_Y": train_Y,
            "covar_module": get_covar_module(d),
        }
        if use_fixed_noise:
            if noise_var is None:
                # noiseless
                train_yvar = torch.full_like(train_Y, 1e-7) * train_Y.std(dim=0).pow(2)
            else:
                train_yvar = noise_var.expand(train_Y.shape)
            model_kwargs["train_Yvar"] = train_yvar
        else:
            model_kwargs["likelihood"] = get_likelihood()
        model = base_model_class(**model_kwargs)
        mll = ExactMarginalLogLikelihood(model.likelihood, model)
    return mll, model


def get_covar_module(d: int):
    """Build the covariance module used by the GP models.

    Args:
        d: The input dimension; one lengthscale is used per dimension.

    Returns:
        A scaled Matern-5/2 ARD kernel whose lengthscales and outputscale are
        both constrained to `[1e-2, 1e2]` via `LogTransformedInterval`, with
        initial values of one.
    """
    base_kernel = MaternKernel(
        nu=2.5,
        ard_num_dims=d,
        lengthscale_constraint=LogTransformedInterval(1e-2, 1e2, initial_value=1.0),
    )
    return ScaleKernel(
        base_kernel,
        outputscale_constraint=LogTransformedInterval(1e-2, 1e2, initial_value=1),
    )


def get_likelihood():
    """Build the Gaussian likelihood used when the noise level is inferred.

    Returns:
        A `GaussianLikelihood` whose noise is constrained to `[1e-4, 1e-2]`
        via `LogTransformedInterval`, initialized at `1e-3`.
    """
    # NOTE implies std in [1e-2, 1e-1], should check for noisy experiments
    # if we allow std down to 1e-3, model fitting errors pop up for constraints
    noise_constraint = LogTransformedInterval(1e-4, 1e-2, initial_value=1e-3)
    return GaussianLikelihood(noise_constraint=noise_constraint)


def get_acqf(
    label: str,
    model: GPyTorchModel,
    X_baseline: Tensor,
    train_Y: Optional[Tensor],
    base_function: BaseTestProblem,
    standardize_tf: Standardize,
    **kwargs,
) -> AcquisitionFunction:
    r"""Construct the acquisition function.

    Args:
        label: The name of the acquisition function. For multi-objective
            problems one of "ehvi", "logehvi", "qehvi", "qlogehvi", "qnehvi";
            otherwise one of "ei", "qei", "logei", "qlogei", "pi", "logpi",
            "qpi", "qlogpi", "kg", "gibbon", "jes", "cei", "logcei".
        model: The fitted model.
        X_baseline: The previously evaluated designs, normalized to the
            unit cube.
        train_Y: The observations at the previously evaluated points. For
            constrained problems, column 0 is read as the objective and the
            remaining columns as constraint values. Required by every label
            that needs an incumbent value or a partitioning.
        base_function: The test problem; its type decides whether constraints
            and/or a multi-objective reference point are set up.
        standardize_tf: The outcome transform applied to the problem's
            reference point (multi-objective problems only).
        kwargs: Acquisition function-specific kwargs. If "mc_samples" is given
            without a "sampler", a Sobol QMC sampler of that size is created.
            Only the kwargs accepted by the chosen acquisition function are
            forwarded to it.

    Returns:
        The acquisition function.

    Raises:
        NotImplementedError: If `label` is not supported for the problem type.
    """
    # pulling MC samples
    if "mc_samples" in kwargs and "sampler" not in kwargs:
        # SOBOL sampler works for mc_samples <~ 21k, SOBOL_ENGINE.MAX_DIM
        kwargs["sampler"] = SobolQMCNormalSampler(
            sample_shape=torch.Size([kwargs.pop("mc_samples")]), seed=1234
        )
    if isinstance(base_function, ConstrainedBaseTestProblem):
        objective_index = 0
        # For KG, will need to use ConstrainedMCObjective
        mc_objective = GenericMCObjective(
            lambda samples, X: samples[..., objective_index],
        )

        def _make_constraint(output_index: int):
            # Bind the index per constraint. A lambda written directly inside
            # the comprehension would late-bind the loop variable, making
            # every constraint select the last constraint column.
            return lambda samples: samples[..., output_index]

        # this is a list of functions that subselect the constraint output dimensions
        constraints = [
            _make_constraint(i + base_function.num_objectives)
            for i in range(base_function.num_constraints)
        ]
        if label == "kg":  # for KG, need to use ConstrainedMCObjective
            mc_objective = ConstrainedMCObjective(mc_objective, constraints)
        else:  # for non-KG, use generic objective and pass constraints to acqf directly
            kwargs.setdefault("constraints", constraints)
        kwargs.setdefault("objective", mc_objective)

    if isinstance(base_function, MultiObjectiveTestProblem):
        # standardize ref point
        ref_point = standardize_tf(base_function.ref_point.unsqueeze(0))[0].squeeze(0)
        if "ehvi" in label and "nehvi" not in label:
            bd = FastNondominatedPartitioning(ref_point=ref_point, Y=train_Y)
            if label == "ehvi":
                acq_func = ExpectedHypervolumeImprovement(
                    model=model,
                    ref_point=ref_point,
                    partitioning=bd,
                    **_filter_kwargs(ExpectedHypervolumeImprovement, **kwargs),
                )
            elif label == "logehvi":
                acq_func = LogExpectedHypervolumeImprovement(
                    model=model,
                    ref_point=ref_point,
                    partitioning=bd,
                    **_filter_kwargs(LogExpectedHypervolumeImprovement, **kwargs),
                )
            elif label == "qehvi":
                acq_func = qExpectedHypervolumeImprovement(
                    model=model,
                    ref_point=ref_point,
                    partitioning=bd,
                    **_filter_kwargs(qExpectedHypervolumeImprovement, **kwargs),
                )
            elif label == "qlogehvi":
                acq_func = qLogExpectedHypervolumeImprovement(
                    model=model,
                    ref_point=ref_point,
                    partitioning=bd,
                    **_filter_kwargs(qLogExpectedHypervolumeImprovement, **kwargs),
                )
            else:
                raise NotImplementedError(
                    f"Unsupported multi-objective acquisition function: {label}"
                )
        elif label == "qnehvi":
            acq_func = qNoisyExpectedHypervolumeImprovement(
                model=model,
                X_baseline=X_baseline,
                ref_point=ref_point,
                **_filter_kwargs(qNoisyExpectedHypervolumeImprovement, **kwargs),
            )
        else:
            raise NotImplementedError(
                f"Unsupported multi-objective acquisition function: {label}"
            )
    else:
        # The MC probability-of-improvement variants need the incumbent too,
        # but their labels do not contain "ei".
        if "ei" in label or label in ("qpi", "qlogpi"):
            if isinstance(base_function, ConstrainedBaseTestProblem):
                best_feasible_f = get_best_feasible_f(
                    obj=train_Y[..., [0]], cons=train_Y[..., 1:]
                )
            else:
                best_feasible_f = train_Y.max()

        if label == "ei":
            acq_func = ExpectedImprovement(
                model=model,
                # best observed value
                best_f=train_Y.max(),
                **_filter_kwargs(ExpectedImprovement, **kwargs),
            )
        elif label == "qei":
            acq_func = qExpectedImprovement(
                model=model,
                best_f=best_feasible_f,
                **_filter_kwargs(qExpectedImprovement, **kwargs),
            )
        elif label == "logei":
            acq_func = LogExpectedImprovement(
                model=model,
                best_f=train_Y.max(),
                **_filter_kwargs(LogExpectedImprovement, **kwargs),
            )
        elif label == "qlogei":
            acq_func = qLogExpectedImprovement(
                model=model,
                best_f=best_feasible_f,
                **_filter_kwargs(qLogExpectedImprovement, **kwargs),
            )
        elif label == "pi":
            acq_func = ProbabilityOfImprovement(
                model=model,
                best_f=train_Y.max(),
                **_filter_kwargs(ProbabilityOfImprovement, **kwargs),
            )
        elif label == "logpi":
            acq_func = LogProbabilityOfImprovement(
                model=model,
                best_f=train_Y.max(),
                **_filter_kwargs(LogProbabilityOfImprovement, **kwargs),
            )
        elif label == "qpi":
            # Filter against the MC class itself so that MC-only kwargs
            # (e.g. sampler, objective) are not dropped.
            acq_func = qProbabilityOfImprovement(
                model=model,
                best_f=best_feasible_f,
                **_filter_kwargs(qProbabilityOfImprovement, **kwargs),
            )
        elif label == "qlogpi":
            acq_func = qLogProbabilityOfImprovement(
                model=model,
                best_f=best_feasible_f,
                **_filter_kwargs(qLogProbabilityOfImprovement, **kwargs),
            )
        elif label == "kg":
            num_fantasies = (
                kwargs["sampler"].sample_shape[0] if "sampler" in kwargs else 64
            )
            acq_func = qKnowledgeGradient(
                model=model,
                num_fantasies=num_fantasies,  # could this be automatic?
                **_filter_kwargs(qKnowledgeGradient, **kwargs),
            )
        elif label == "gibbon":
            candidate_set = kwargs.get(
                "candidate_set",
                torch.rand(
                    kwargs.get("candidate_set_size", 1024),
                    X_baseline.shape[-1],  # = d
                    device=train_Y.device,
                    dtype=train_Y.dtype,
                ),
            )
            acq_func = qLowerBoundMaxValueEntropy(
                model=model,
                candidate_set=candidate_set,
                **_filter_kwargs(qLowerBoundMaxValueEntropy, **kwargs),
            )
        elif label == "jes":
            standard_bounds = torch.ones(
                2,
                X_baseline.shape[-1],
                dtype=X_baseline.dtype,
                device=X_baseline.device,
            )  # 2 x d
            standard_bounds[0] = 0
            optimal_inputs, optimal_outputs = get_optimal_samples(
                model=model,
                bounds=standard_bounds,
                num_optima=kwargs.get("num_optima", 64),  # default in input_constructor
                maximize=True,
            )
            acq_func = qJointEntropySearch(
                model=model,
                optimal_inputs=optimal_inputs,
                optimal_outputs=optimal_outputs,
                **_filter_kwargs(qJointEntropySearch, **kwargs),
            )
        elif label == "cei":
            # negative values imply feasibility, so there's no lower bound
            constraints = {
                i + 1: [None, 0.0] for i in range(base_function.num_constraints)
            }
            objective_index = 0
            acq_func = ConstrainedExpectedImprovement(
                model=model,
                best_f=best_feasible_f,
                objective_index=objective_index,
                constraints=constraints,
                maximize=True,
            )
        elif label == "logcei":
            # negative values imply feasibility, so there's no lower bound
            constraints = {
                i + 1: [None, 0.0] for i in range(base_function.num_constraints)
            }
            # set explicitly rather than relying on an earlier branch
            objective_index = 0
            acq_func = LogConstrainedExpectedImprovement(
                model=model,
                best_f=best_feasible_f,
                objective_index=objective_index,
                constraints=constraints,
                maximize=True,
            )
        else:
            raise NotImplementedError(
                f"Unsupported acquisition function: {label}"
            )

    return acq_func


def get_problem(name: str, dim: Optional[int] = None, **kwargs) -> BaseTestProblem:
    r"""Initialize the test function.

    All problems are constructed with `negate=True`; the constrained
    engineering problems additionally use `negate_slack=True`.

    Args:
        name: The name of the test problem.
        dim: The input dimension. Ignored by problems that are constructed
            without a `dim` argument (e.g. "bc", "branin", "shekel" and the
            constrained engineering problems).
        kwargs: Optional problem settings. Only "num_objectives" (default 2)
            is read, and only for the DTLZ/ZDT problems.

    Returns:
        The test problem.

    Raises:
        UnsupportedError: If `name` is not a known problem.
    """
    multi_objective_classes = {
        "dtlz2": DTLZ2,
        "zdt1": ZDT1,
        "zdt2": ZDT2,
        "zdt3": ZDT3,
    }
    if name in multi_objective_classes:
        return multi_objective_classes[name](
            dim=dim, num_objectives=kwargs.get("num_objectives", 2), negate=True
        )
    elif name == "bc":
        return BraninCurrin(negate=True)
    elif name == "hartmann":
        return Hartmann(dim=dim, negate=True)
    elif name == "branin":
        return Branin(negate=True)
    elif name == "ackley":
        # defining problem on smaller, asymmetric domain
        bounds = [(-32.768 / 4, 32.768 / 2) for _ in range(dim)]
        return Ackley(dim=dim, negate=True, bounds=bounds)
    elif name == "rastrigin":
        # defining problem on smaller, asymmetric domain
        bounds = [(-5.12 / 2, 5.12) for _ in range(dim)]
        return Rastrigin(dim=dim, negate=True, bounds=bounds)
    elif name == "sos":
        return SumOfSquares(dim=dim, negate=True)
    elif name == "michalewicz":
        return Michalewicz(dim=dim, negate=True)
    elif name == "levy":
        return Levy(dim=dim, negate=True)
    elif name == "shekel":
        return Shekel(negate=True)  # 4-dimensional
    elif name in ("styblinski", "styblinksi"):
        # "styblinksi" is a historical misspelling, kept as an alias so that
        # existing callers keep working.
        return StyblinskiTang(dim=dim, negate=True)
    elif name == "pressure_vessel":
        # negating slack too, b/c problem definitions consider c(x) > 0 feasible,
        # but BoTorch constraint implementation considers c(x) < 0 feasible.
        return PressureVesselDesign(negate=True, negate_slack=True)
    elif name == "speed_reducer":
        return SpeedReducer(negate=True, negate_slack=True)
    elif name == "welded_beam":
        return WeldedBeamSO(negate=True, negate_slack=True)
    elif name == "tension_compression":
        return TensionCompressionString(negate=True, negate_slack=True)
    else:
        raise UnsupportedError(f"{name} is unsupported")


def get_best_feasible_f(obj: Tensor, cons: Tensor, allow_inf: bool = False) -> Tensor:
    """Return the best objective value among the feasible observations.

    A point is feasible if all of its constraint values are non-positive.

    Args:
        obj: The observed objective values, with observations along dim 0.
        cons: The observed constraint values, with one constraint per entry
            of the last dimension.
        allow_inf: If True and nothing is feasible, return negative infinity
            instead of a finite fallback value.

    Returns:
        The maximum of `obj` over the feasible observations (reduced over
        dim 0). If nothing is feasible: `-inf` if `allow_inf`, otherwise the
        worst observed objective minus 10 standard deviations (or just the
        worst observed objective if there is only a single observation).
    """
    # assumes negative values imply feasibility and obj is the objective.
    is_feasible = (cons <= 0).all(dim=-1)
    if is_feasible.any():
        return obj[is_feasible].amax(dim=0)
    if allow_inf:  # this is important in reporting statistics
        return torch.tensor([-torch.inf], dtype=obj.dtype, device=obj.device)
    # if there's nothing feasible, default to 10 standard deviations below
    # worst observed objective value
    worst_obj = obj.amin(dim=0)
    if obj.shape[0] < 2:
        # The sample standard deviation of a single observation is NaN, which
        # would turn the incumbent into NaN; fall back to the observed value.
        return worst_obj
    return worst_obj - 10 * obj.std(dim=0)


###################### defining log transformed interval ######################
import math
from typing import Any, Dict, List, Optional, Union

import torch
from botorch.exceptions import UnsupportedError
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from gpytorch import settings
from gpytorch.constraints import Interval
from torch import Tensor
from torch.nn import Parameter


class LogTransformedInterval(Interval):
    """Modification of the GPyTorch interval class.

    The Interval class in GPyTorch will map the parameter to the range [0, 1] before
    applying the inverse transform. We don't want to do this when using log as an
    inverse transform. This class will skip this step and apply the log transform
    directly to the parameter values so we can optimize log(parameter) under the bound
    constraints log(lower) <= log(parameter) <= log(upper).
    """

    def __init__(self, lower_bound, upper_bound, initial_value=None):
        """Set up an interval constraint with exp / log as the transform pair.

        Args:
            lower_bound: Lower bound of the constrained parameter.
            upper_bound: Upper bound of the constrained parameter.
            initial_value: Optional initial value of the constrained parameter.

        Raises:
            RuntimeError: In GPyTorch debug mode, if a bound is not finite.
        """
        super().__init__(
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            transform=torch.exp,
            inv_transform=torch.log,
            initial_value=initial_value,
        )

        # Keep the untransformed initial value around as a buffer, matching
        # the dtype / device of the lower bound.
        if initial_value is None:
            untransformed = None
        else:
            untransformed = torch.tensor(initial_value).to(self.lower_bound)
        self.register_buffer("initial_value_untransformed", untransformed)

        if settings.debug.on():
            largest_upper = torch.max(self.upper_bound)
            smallest_lower = torch.min(self.lower_bound)
            if largest_upper == math.inf or smallest_lower == -math.inf:
                raise RuntimeError(
                    "Cannot make an Interval directly with non-finite bounds. Use a derived class like "
                    "GreaterThan or LessThan instead."
                )

    def transform(self, tensor):
        """Apply the stored transform directly; pass through if not enforced."""
        if self.enforced:
            return self._transform(tensor)
        return tensor

    def inverse_transform(self, transformed_tensor):
        """Apply the stored inverse transform directly; pass through if not enforced."""
        if self.enforced:
            return self._inv_transform(transformed_tensor)
        return transformed_tensor
